package web

import (
	"bytes"
	"encoding/json"
	_const "gitee.com/lstack_light/go-light/internal/pkg/global"
	"gitee.com/lstack_light/go-light/internal/pkg/utils"
	"gitee.com/lstack_light/go-light/internal/web/binding"
	"gitee.com/lstack_light/go-light/pkg/server/respData"
	"github.com/logrusorgru/aurora"
	"io"
	"io/ioutil"
	"net/http"
	"reflect"
	"strings"
	"sync"
)

/**
  @author: light
  @since: 2023/8/5
  @desc: 引擎
*/

// Engine is the framework's HTTP engine: it owns the routing state, the
// registered middlewares and their interceptor configuration, and serves
// requests through ServeHTTP.
type Engine struct {
	// Embedded routing state. InitEngine allocates its methodTree and
	// routerTable maps; ServeHTTP looks routes up through getMethodTree.
	*methodTrees
	// MiddleWares holds the registered middlewares in registration order.
	// SetControllers wraps ServeHTTP with them starting from the last entry,
	// so the first entry's CoreHandler produces the outermost handler.
	MiddleWares []MiddleWare
	// Interceptor maps a middleware Key to a list of strings that is handed to
	// every CoreHandler. UseMiddleware seeds each key with
	// _const.TheSlashSeparator.
	// NOTE(review): the entries look like path patterns — confirm against the
	// middleware implementations.
	Interceptor map[string][]string
}

var (
	// lightEngine is the package-level Engine instance handed out by InitEngine.
	lightEngine *Engine
	// once guards the one-time construction of lightEngine in InitEngine.
	once        sync.Once
)

//
// InitEngine
//  @Description: Returns the package-level framework engine, building it on
//  the first call. Construction runs inside a sync.Once and is skipped when
//  lightEngine has already been assigned.
//  @return *Engine pointer to the shared framework engine
//
func InitEngine() *Engine {
	once.Do(func() {
		// Leave an already-assigned engine untouched.
		if nil != lightEngine {
			return
		}
		lightEngine = &Engine{
			methodTrees: &methodTrees{
				methodTree:  make(map[string]*tree),
				routerTable: make(map[string][]string),
			},
			Interceptor: make(map[string][]string),
		}
	})
	return lightEngine
}
// ServeHTTP implements http.Handler for the engine.
//
// Processing steps:
//  1. Look up the routing tree for the request method and match the request
//     URI; an unknown method or route is answered with 404.
//  2. Bind the request data onto the handler arguments; a binding failure is
//     logged and answered with 500.
//  3. Validate the bound arguments; a validation failure becomes a
//     respData.SystemErrorResponse whose message is English when the
//     Accept-Language header contains "en-US".
//  4. Otherwise invoke the handler. A result carrying a File is streamed as an
//     attachment (404 when the file content is empty); any other result is
//     written as JSON with the response's HttpStatusCode.
//
// The outcome (status code, status text and, for JSON replies, the body) is
// also mirrored onto r.Response.
// NOTE(review): presumably wrapping middlewares read r.Response after the
// call — confirm against the middleware implementations.
func (e *Engine) ServeHTTP(wr http.ResponseWriter, r *http.Request) {
	r.Response = &http.Response{}

	// Route lookup.
	routerTree := e.getMethodTree(r.Method)
	if nil == routerTree {
		replyWithStatus(wr, r, http.StatusNotFound)
		return
	}
	paramMap := make(utils.Map)
	args, handler := routerTree.MatchingRouter(r.RequestURI, &paramMap)
	if !handler.IsValid() {
		replyWithStatus(wr, r, http.StatusNotFound)
		return
	}

	// Parameter binding.
	bind := &binding.Bind{Request: r, SourceData: args}
	if err := bind.BindParams(paramMap); nil != err {
		msg := "参数绑定异常" + err.Error()
		utils.GetLogger().Println(aurora.Yellow(msg))
		r.Response.StatusCode = http.StatusInternalServerError
		r.Response.Status = msg
		// Without an explicit header the client would receive an implicit
		// 200 with an empty body.
		wr.WriteHeader(http.StatusInternalServerError)
		return
	}

	// Parameter validation.
	var resp *respData.Response
	if err := binding.NewValidator(args).Validate(); nil != err {
		msg := "参数校验异常" + err.Error()
		if strings.Contains(r.Header.Get(_const.AcceptLanguage), "en-US") {
			msg = "Parameter check exception" + err.Error()
		}
		resp = respData.SystemErrorResponse(msg)
	} else {
		// Invoke the handler. The comma-ok assertion leaves resp nil when the
		// handler returns nothing or a value of another type.
		if result := handler.Call(args); 0 < len(result) {
			resp, _ = result[0].Interface().(*respData.Response)
		}
		if nil == resp {
			utils.GetLogger().Println(aurora.Yellow("处理器返回值异常"))
			replyWithStatus(wr, r, http.StatusInternalServerError)
			return
		}
		// File download.
		if nil != resp.File {
			if 0 == len(resp.File.Content) {
				replyWithStatus(wr, r, http.StatusNotFound)
				return
			}
			wr.Header().Set("Content-Disposition", "attachment; filename="+resp.File.Name)
			// Headers must be sent before the body; a WriteHeader issued
			// after the copy would be superfluous.
			replyWithStatus(wr, r, http.StatusOK)
			if _, err := io.Copy(wr, bytes.NewReader(resp.File.Content)); nil != err {
				utils.GetLogger().Println(aurora.Yellow("写入结果异常" + err.Error()))
			}
			return
		}
	}

	// Encode before sending the status line so that an encoding failure can
	// still be reported as 500.
	payload, jsonErr := json.Marshal(resp)
	if nil != jsonErr {
		utils.GetLogger().Println(aurora.Yellow("格式化结果异常" + jsonErr.Error()))
		replyWithStatus(wr, r, http.StatusInternalServerError)
		return
	}

	// http.ResponseWriter.WriteHeader panics on codes outside 100-999, so an
	// unset or invalid code falls back to 200.
	status := resp.HttpStatusCode
	if status < 100 || status > 999 {
		status = http.StatusOK
	}
	r.Response.StatusCode = status
	r.Response.Status = http.StatusText(status)
	r.Response.Body = ioutil.NopCloser(bytes.NewBuffer(payload))
	wr.WriteHeader(status)
	if _, err := wr.Write(payload); nil != err {
		utils.GetLogger().Println(aurora.Yellow("写入结果异常" + err.Error()))
	}
}

// replyWithStatus records code and its standard status text on r.Response and
// sends code as the response status line.
func replyWithStatus(wr http.ResponseWriter, r *http.Request, code int) {
	r.Response.StatusCode = code
	r.Response.Status = http.StatusText(code)
	wr.WriteHeader(code)
}

// SetControllers initialises the given controllers and installs the engine,
// wrapped by its middlewares, on the default serve mux.
//
// For each controller, every method in the method set of controller.module is
// called once via reflection: pointer parameters receive a freshly allocated
// value, other parameters receive their zero value, and a variadic tail is
// left empty.
// NOTE(review): presumably each of these methods registers its route when
// called — confirm against the Controller implementation.
//
// After printing the routing table, ServeHTTP is wrapped with the registered
// middlewares from last to first, and the resulting chain is registered with
// http.HandleFunc under _const.TheSlashSeparator. Nil controllers, nil modules
// and middlewares without a CoreHandler are skipped.
//
// The return value is always nil.
func (e *Engine) SetControllers(controllers ...*Controller) http.Handler {
	for _, controller := range controllers {
		if nil == controller {
			continue
		}
		module := reflect.ValueOf(controller.module)
		// reflect.ValueOf(nil) yields the zero Value, on which NumMethod panics.
		if !module.IsValid() {
			continue
		}
		for i := 0; i < module.NumMethod(); i++ {
			method := module.Method(i)
			methodType := method.Type()
			numIn := methodType.NumIn()
			// Value.Call builds the variadic slice itself, so the variadic
			// parameter gets no explicit argument.
			if methodType.IsVariadic() {
				numIn--
			}
			params := make([]reflect.Value, 0, numIn)
			for j := 0; j < numIn; j++ {
				in := methodType.In(j)
				if reflect.Ptr == in.Kind() {
					params = append(params, reflect.New(in.Elem()))
					continue
				}
				// Type.Elem panics for kinds without an element type, so
				// non-pointer parameters get their zero value instead.
				params = append(params, reflect.Zero(in))
			}
			method.Call(params)
		}
	}
	e.printSystemLogs()
	handler := e.ServeHTTP
	for i := len(e.MiddleWares) - 1; i >= 0; i-- {
		if nil == e.MiddleWares[i].CoreHandler {
			continue
		}
		handler = e.MiddleWares[i].CoreHandler(handler, e.Interceptor)
	}
	http.HandleFunc(_const.TheSlashSeparator, func(writer http.ResponseWriter, request *http.Request) {
		handler(writer, request)
	})
	return nil
}

// printSystemLogs writes one log line per entry of the router table: the
// request method, formatted by utils.BeautifyRequestMode, followed by the
// route. Lines follow Go's unspecified map iteration order.
func (e *Engine) printSystemLogs() {
	out := utils.GetLogger()
	for verb, paths := range e.getRouterTable() {
		for _, path := range paths {
			out.Println(utils.BeautifyRequestMode(verb), path)
		}
	}
}

// Handler is the request-handling function type used for middleware chaining;
// it has the same signature as Engine.ServeHTTP.
type Handler func(http.ResponseWriter, *http.Request)

// MiddleWare describes one middleware that can be registered on an Engine.
type MiddleWare struct {
	// Key identifies the middleware. UseMiddleware uses it to skip duplicate
	// registrations and as the key into Engine.Interceptor.
	Key         string
	// CoreHandler receives the next handler in the chain together with the
	// engine's Interceptor map and returns the handler to run in its place.
	CoreHandler func(next Handler, interceptor map[string][]string) Handler
}

// UseMiddleware registers the given middlewares on the engine and returns the
// engine so calls can be chained.
//
// A middleware whose Key is already registered is not appended again. In
// every case, including a repeated Key, the Interceptor entry for that Key is
// reset to a single _const.TheSlashSeparator element. Nil arguments are
// ignored, and the Interceptor map is allocated on demand so an Engine that
// was not created through InitEngine can be used as well.
func (e *Engine) UseMiddleware(middleWares ...*MiddleWare) *Engine {
	for _, middleWare := range middleWares {
		if nil == middleWare {
			continue
		}
		exists := false
		for i := range e.MiddleWares {
			if e.MiddleWares[i].Key == middleWare.Key {
				exists = true
				break
			}
		}
		if !exists {
			e.MiddleWares = append(e.MiddleWares, *middleWare)
		}
		// Writing to a nil map panics.
		if nil == e.Interceptor {
			e.Interceptor = make(map[string][]string)
		}
		e.Interceptor[middleWare.Key] = []string{_const.TheSlashSeparator}
	}
	return e
}
